import os
import math
import argparse
from typing import List, Union
from tqdm import tqdm
from omegaconf import ListConfig
import imageio

import torch
import numpy as np
from einops import rearrange
import torchvision.transforms as TT


from sat.model.base_model import get_model
from sat.training.model_io import load_checkpoint_ref
from sat import mpu

from diffusion_video import SATVideoDiffusionEngine
from arguments import get_args
from torchvision.transforms.functional import center_crop, resize
from torchvision.transforms import InterpolationMode
from PIL import Image

def read_from_cli():
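    """Yield (prompt, index) pairs typed interactively on stdin until EOF."""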
    cnt = 0
    try:
        while True:
            x = input('Please input English text (Ctrl-D quit): ')
            yield x.strip(), cnt
            cnt += 1
    except EOFError:
        pass


def read_from_file(p, rank=0, world_size=1):
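    """Yield (prompt, index) pairs from a text file, sharding lines across data-parallel ranks."""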
    with open(p, 'r', encoding="utf-8") as fin:
        cnt = -1
        for l in fin:
            cnt += 1
            if cnt % world_size != rank:
                continue
            yield l.strip(), cnt

def read_from_file_reverse(p, rank=0, world_size=1):
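    """Same as read_from_file, but iterate over the file's lines in reverse order."""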
    with open(p, 'r', encoding="utf-8") as fin:
        lines = fin.readlines()[::-1]
        for cnt, l in enumerate(lines):
            if cnt % world_size != rank:
                continue
            yield l.strip(), cnt

def get_unique_embedder_keys_from_conditioner(conditioner):
    return list({x.input_key for x in conditioner.embedders})

def get_batch(keys, value_dict, N: Union[List, ListConfig], T=None, device="cuda"):
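    """Build conditional and unconditional batches for the conditioner's embedders.

    Text prompts are tiled to the requested batch shape N; tensor entries are
    cloned into the unconditional batch so both share the same non-text inputs.
    """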
    batch = {}
    batch_uc = {}

    for key in keys:
        if key == "txt":
            batch["txt"] = (
                np.repeat([value_dict["prompt"]], repeats=math.prod(N))
                .reshape(N)
                .tolist()
            )
            batch_uc["txt"] = (
                np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N))
                .reshape(N)
                .tolist()
            )
        else:
            batch[key] = value_dict[key]
        
    if T is not None:
        batch["num_video_frames"] = T

    for key in batch.keys():
        if key not in batch_uc and isinstance(batch[key], torch.Tensor):
            batch_uc[key] = torch.clone(batch[key])
    return batch, batch_uc

def save_video_as_grid_and_mp4(
    video_batch: torch.Tensor, save_path: str, fps: int = 5, args=None, key=None
):
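    """Save each video in the batch (values in [0, 1], shape [T, C, H, W]) as an mp4 under save_path.

    Despite the name, only mp4 files are written; no grid image is produced.
    """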
    os.makedirs(save_path, exist_ok=True)

    for i, vid in enumerate(video_batch):
        frames = []
        for frame in vid:
            frame = rearrange(frame, "c h w -> h w c")
            frame = (255.0 * frame).cpu().numpy().astype(np.uint8)
            frames.append(frame)
        now_save_path = os.path.join(save_path, f"{i:06d}.mp4")
        with imageio.get_writer(now_save_path, fps=fps) as writer:
            for frame in frames:
                writer.append_data(frame)

def resize_for_rectangle_crop(arr, image_size, reshape_mode='random'):
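    """Resize a [B, C, H, W] tensor so the target rectangle fits, then crop to image_size.

    reshape_mode 'center' crops centrally; 'random' (and 'none') pick a random offset.
    """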
    if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
        arr = resize(arr, size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
                     interpolation=InterpolationMode.BICUBIC)
    else:
        arr = resize(arr, size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
                     interpolation=InterpolationMode.BICUBIC)

    h, w = arr.shape[2], arr.shape[3]
    arr = arr.squeeze(0)

    delta_h = h - image_size[0]
    delta_w = w - image_size[1]

    if reshape_mode == 'random' or reshape_mode == 'none':
        top = np.random.randint(0, delta_h + 1)
        left = np.random.randint(0, delta_w + 1)
    elif reshape_mode == 'center':
        top, left = delta_h // 2, delta_w // 2
    else:
        raise NotImplementedError
    arr = TT.functional.crop(
        arr, top=top, left=left, height=image_size[0], width=image_size[1]
    )
    return arr

def sampling_main(args, model_cls):
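    """Sample videos with a SATVideoDiffusionEngine, reading prompts from the CLI or a text file."""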
    if isinstance(model_cls, type):
        model = get_model(args, model_cls)
    else:
        model = model_cls
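    # The checkpoint iteration is hard-coded; change `iteration` to load a different checkpoint.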
    iteration = 1000
    load_checkpoint_ref(model, args, specific_iteration=iteration)
    model.eval()

    # Determine the data-parallel rank up front; `rank` is also used for logging below.
    rank, world_size = mpu.get_data_parallel_rank(), mpu.get_data_parallel_world_size()
    print("rank and world_size", rank, world_size)
    if args.input_type == 'cli':
        data_iter = read_from_cli()
    elif args.input_type == 'txt':
        data_iter = read_from_file(args.input_file, rank=rank, world_size=world_size)
        # data_iter = read_from_file_reverse(args.input_file, rank=rank, world_size=world_size)
    else:
        raise NotImplementedError
    
    image_size = [480, 720]
    #image_size = [720, 1088]
    
    if args.image2video:
        chained_transforms = []
        # chained_transforms.append(TT.Resize(size=image_size, interpolation=1))
        chained_transforms.append(TT.ToTensor())
        transform = TT.Compose(chained_transforms)

    sample_func = model.sample
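    # T: latent frame count, H/W: output resolution, C: latent channels,
    # F: spatial downsampling factor of the first-stage autoencoder (latents are H // F x W // F).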
    T, H, W, C, F = args.sampling_num_frames, image_size[0], image_size[1], args.latent_channels, 8
    num_samples = [1]
    force_uc_zero_embeddings = ['txt']
    with torch.no_grad():
        for text, cnt in tqdm(data_iter):
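            # In image2video mode, each input line has the form "prompt@@image_path".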
            if args.image2video:
                text, image_path = text.split('@@')
                if not os.path.exists(image_path):
                    continue
                image = Image.open(image_path).convert('RGB')
                image = transform(image).unsqueeze(0).to('cuda')
                image = resize_for_rectangle_crop(image, image_size, reshape_mode='center').unsqueeze(0)
                # Normalize from [0, 1] to [-1, 1] and add a frame dimension: [B, C, T=1, H, W].
                image = image * 2.0 - 1.0
                image = image.unsqueeze(2).to(torch.bfloat16)
                # Encode the reference frame into latent space, then reorder to [B, T, C, H', W'].
                image = model.encode_first_stage_ref(image, None)
                image = image.permute(0, 2, 1, 3, 4).contiguous()
                # Zero-pad the remaining T - 1 latent frames so the concat conditioning covers the full clip.
                pad_shape = (image.shape[0], T - 1, C, H // F, W // F)
                image = torch.concat([image, torch.zeros(pad_shape, device=image.device, dtype=image.dtype)], dim=1)
            else:
                image = None
            print("rank:", rank, "start to process", text, cnt)
            # TODO: broadcast image2video
            value_dict = {
                'prompt': text,
                'negative_prompt': '',
                'num_frames': torch.tensor(T).unsqueeze(0),
            }

            batch, batch_uc = get_batch(
                get_unique_embedder_keys_from_conditioner(model.conditioner),
                value_dict,
                num_samples
            )
            for key in batch:
                if isinstance(batch[key], torch.Tensor):
                    print(key, batch[key].shape)
                elif isinstance(batch[key], list):
                    print(key, [len(l) for l in batch[key]])
                else:
                    print(key, batch[key])
            c, uc = model.conditioner.get_unconditional_conditioning(
                batch,
                batch_uc=batch_uc,
                force_uc_zero_embeddings=force_uc_zero_embeddings,
            )

            # Move every non-text conditioning tensor to the GPU, truncated to the sample count.
            for k in c:
                if not k == "crossattn":
                    c[k], uc[k] = map(
                        lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc)
                    )

            c["concat"] = image
            # uc["concat"] = torch.zeros_like(image)
            uc["concat"] = image

            for index in range(args.batch_size):
                samples_z = sample_func(
                    c,
                    uc=uc,
                    batch_size=1,
                    shape=(T, C, H // F, W // F),
                )
                samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous()
                if args.only_save_latents:
                    samples_z = 1.0 / model.scale_factor * samples_z
                    save_path = os.path.join(args.output_dir, str(cnt) + '_' + text.replace(' ', '_').replace('/', '')[:20], str(index))
                    os.makedirs(save_path, exist_ok=True)
                    torch.save(samples_z, os.path.join(save_path, 'latent.pt'))
                    with open(os.path.join(save_path, 'text.txt'), 'w') as f:
                        f.write(text)
                else:
                    samples_x = model.decode_first_stage(samples_z).to(torch.float32)
                    samples_x = samples_x.permute(0, 2, 1, 3, 4).contiguous()
                    samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()
                    if args.image2video:
                        save_path = os.path.join(args.output_dir, str(iteration), os.path.splitext(os.path.basename(image_path))[0])
                    else:
                        save_path = os.path.join(args.output_dir, str(iteration), str(cnt) + '_' + text.replace(' ', '_').replace('/', '')[:20])
                    if mpu.get_model_parallel_rank() == 0:
                        save_video_as_grid_and_mp4(samples, save_path, fps=args.sampling_fps)

if __name__ == '__main__':
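    # Map OpenMPI launcher environment variables to the names torch.distributed expects.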
    if 'OMPI_COMM_WORLD_LOCAL_RANK' in os.environ:
        os.environ['LOCAL_RANK'] = os.environ['OMPI_COMM_WORLD_LOCAL_RANK']
        os.environ['WORLD_SIZE'] = os.environ['OMPI_COMM_WORLD_SIZE']
        os.environ['RANK'] = os.environ['OMPI_COMM_WORLD_RANK']
    py_parser = argparse.ArgumentParser(add_help=False)
    known, args_list = py_parser.parse_known_args()
    args = get_args(args_list)
    args = argparse.Namespace(**vars(args), **vars(known))
    del args.deepspeed_config
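    # Inference-time overrides: single context-parallel first stage, no model parallelism,
    # no activation checkpointing, and uniform sigma sampling disabled.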
    args.model_config.first_stage_config.params.cp_size = 1
    args.model_config.network_config.params.transformer_args.model_parallel_size = 1
    args.model_config.network_config.params.transformer_args.checkpoint_activations = False
    args.model_config.loss_fn_config.params.sigma_sampler_config.params.uniform_sampling = False
    sampling_main(args, model_cls=SATVideoDiffusionEngine)
